{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "088542f0",
   "metadata": {},
   "source": [
    "# Lesson 2: Datasets For Reinforcement Learning Training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e08cfaf3",
   "metadata": {},
   "source": [
    "### Loading and exploring the datasets\n",
    "\n",
    "\"Reinforcement Learning from Human Feedback\" **(RLHF)** requires the following datasets:\n",
    "- Preference dataset\n",
    "  - Input prompt, candidate response 0, candidate response 1, choice (candidate 0 or 1)\n",
    "- Prompt dataset\n",
    "  - Input prompt only, no response"
   ]
  },
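  {
   "cell_type": "markdown",
   "id": "a3c1f2b0",
   "metadata": {},
   "source": [
    "To make the two record layouts concrete, the cell below builds one hypothetical record of each kind. The preference keys match those read from `sample_preference.jsonl` in this lesson. The text values, the integer `choice`, and the single `input_text` key of the prompt record are assumptions for illustration, not rows from the datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3c1f2b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Hypothetical preference record: placeholder values, not real data\n",
    "example_preference = {\n",
    "    'input_text': 'Placeholder prompt text',\n",
    "    'candidate_0': 'Placeholder response A',\n",
    "    'candidate_1': 'Placeholder response B',\n",
    "    'choice': 1,  # assumed to be the index of the preferred candidate\n",
    "}\n",
    "\n",
    "# Hypothetical prompt record: assumed to hold only the prompt text\n",
    "example_prompt = {'input_text': 'Placeholder prompt text'}\n",
    "\n",
    "# In a JSONL file, each record is one JSON object on its own line\n",
    "print(json.dumps(example_preference))\n",
    "print(json.dumps(example_prompt))"
   ]
  },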
  {
   "cell_type": "markdown",
   "id": "236bce0e",
   "metadata": {},
   "source": [
    "#### Preference dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3048b1b5-a01c-4c89-af5a-ce8f47d13218",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "preference_dataset_path = 'sample_preference.jsonl'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7822c372-b9d0-46c0-86c3-dc38508262e3",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fb35514-6e2b-433c-a6f3-9411846fd2a6",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "preference_data = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35bcc873-cd6a-4108-9c4a-47b7d27faab7",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "with open(preference_dataset_path) as f:\n",
    "    for line in f:\n",
    "        preference_data.append(json.loads(line))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b2290d9",
   "metadata": {},
   "source": [
    "- Print out to explore the preference dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57964e85-d4bd-4b35-b437-e51a4493eb33",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "sample_1 = preference_data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9bff834-4c5e-46c4-8887-d3a29cc8de86",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "print(type(sample_1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63b473d0-bf39-4f8b-99e5-d6f795f33b83",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# This dictionary has four keys\n",
    "print(sample_1.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21d524bb",
   "metadata": {},
   "source": [
    "- Key: 'input_test' is a prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "630d6337-0e6c-4ccb-987a-7d73de91c65a",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "sample_1['input_text']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eb97174-bb2b-400a-b76a-81428063b76a",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Try with another examples from the list, and discover that all data end the same way\n",
    "preference_data[2]['input_text'][-50:]"
   ]
  },
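  {
   "cell_type": "markdown",
   "id": "b7d4e9a0",
   "metadata": {},
   "source": [
    "The same check can be run over every record rather than one at a time. A minimal sketch: `shared_suffix` is a hypothetical helper, not part of the original lesson, that returns the longest ending common to a list of strings. It is shown on made-up strings; the commented line applies it to the loaded data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d4e9a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "def shared_suffix(texts):\n",
    "    # Reverse each string, take the common prefix, then reverse it back\n",
    "    reversed_texts = [t[::-1] for t in texts]\n",
    "    return os.path.commonprefix(reversed_texts)[::-1]\n",
    "\n",
    "# Made-up strings for illustration\n",
    "demo_texts = ['first example text END', 'second example text END']\n",
    "print(repr(shared_suffix(demo_texts)))\n",
    "\n",
    "# On the loaded data:\n",
    "# shared_suffix([d['input_text'] for d in preference_data])"
   ]
  },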
  {
   "cell_type": "markdown",
   "id": "78232d32",
   "metadata": {},
   "source": [
    "- Print 'candidate_0' and 'candidate_1', these are the completions for the same prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a67ab282-3495-4aab-a43a-309978e03529",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "print(f\"candidate_0:\\n{sample_1.get('candidate_0')}\\n\")\n",
    "print(f\"candidate_1:\\n{sample_1.get('candidate_1')}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55d3cc61",
   "metadata": {},
   "source": [
    "- Print 'choice', this is the human labeler's preference for the results completions (candidate_0 and candidate_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dae5b1cb-5411-462c-a122-bbb6264e4111",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "print(f\"choice: {sample_1.get('choice')}\")"
   ]
  },
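  {
   "cell_type": "markdown",
   "id": "c2e8f5d0",
   "metadata": {},
   "source": [
    "After inspecting a single label, it can help to see how the labels are distributed across the whole dataset. The sketch below tallies `choice` values with `collections.Counter`. `count_choices` is a hypothetical helper and the demo records are invented; the commented line shows how it would be applied to the loaded data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2e8f5d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "def count_choices(records):\n",
    "    # Tally how often each value of 'choice' appears\n",
    "    return Counter(r.get('choice') for r in records)\n",
    "\n",
    "# Invented records for illustration\n",
    "demo_records = [{'choice': 0}, {'choice': 1}, {'choice': 1}]\n",
    "print(count_choices(demo_records))\n",
    "\n",
    "# On the loaded data:\n",
    "# count_choices(preference_data)"
   ]
  },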
  {
   "cell_type": "markdown",
   "id": "8eda6787",
   "metadata": {},
   "source": [
    "#### Prompt dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82c2785a-87df-4bc7-adb8-ccbdfff7b819",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "prompt_dataset_path = 'sample_prompt.jsonl'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa3fd8ef-c9f3-4bc3-9404-cbe91fcae150",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "prompt_data = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0e47624-472f-4f20-8f02-219b38e166ee",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "with open(prompt_dataset_path) as f:\n",
    "    for line in f:\n",
    "        prompt_data.append(json.loads(line))"
   ]
  },
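  {
   "cell_type": "markdown",
   "id": "d5a9c3e0",
   "metadata": {},
   "source": [
    "The same loading loop is used for both datasets, so it could be factored into a helper. A minimal sketch, assuming UTF-8 files: `load_jsonl` is a hypothetical name, and skipping blank lines is a defensive assumption about the file format rather than something the lesson requires."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5a9c3e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "def load_jsonl(path):\n",
    "    # Read a JSON Lines file: one JSON object per non-empty line\n",
    "    records = []\n",
    "    with open(path, encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            line = line.strip()\n",
    "            if line:\n",
    "                records.append(json.loads(line))\n",
    "    return records\n",
    "\n",
    "# Usage with the paths defined in this lesson:\n",
    "# preference_data = load_jsonl(preference_dataset_path)\n",
    "# prompt_data = load_jsonl(prompt_dataset_path)"
   ]
  },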
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1c90028-5da8-435d-a203-92a661154598",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Check how many prompts there are in this dataset\n",
    "len(prompt_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e08fd3a7",
   "metadata": {},
   "source": [
    "**Note**: It is important that the prompts in both datasets, the preference and the prompt, come from the same distribution. \n",
    "\n",
    "For this lesson, all the prompts come from the same dataset of [Reddit posts](https://github.com/openai/summarize-from-feedback)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff6da85e-5b6d-4d68-bb8e-70716205f28a",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Function to print the information in the prompt dataset with a better visualization\n",
    "def print_d(d):\n",
    "    for key, val in d.items():        \n",
    "        print(f\"key:{key}\\nval:{val}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e665d332-38a2-4b8c-be5a-6e4c6c3392b1",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "print_d(prompt_data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9825d214-13a1-41c0-8280-b584d2f3bbd0",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Try with another prompt from the list \n",
    "print_d(prompt_data[1])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
